#ifndef XENENGIN_XEN_RENDERER_BUFFER_HPP
#define XENENGIN_XEN_RENDERER_BUFFER_HPP

#include "xen/pch.hpp"
#include "xen/core.hpp"

namespace Xen {

/// Element types that a BufferElement can describe.
///
/// Values are assigned sequentially in declaration order, starting from
/// None = 0.
enum class ShaderDataType {
    /// No type; BufferElement::ShaderDataTypeSize reports 0 bytes for it.
    None = 0,

    // Float scalar and vectors (1-4 components).
    Float,
    Float2,
    Float3,
    Float4,

    // Square float matrices (3x3 and 4x4 components).
    Mat3,
    Mat4,

    // Integer scalar and vectors (1-4 components).
    Int,
    Int2,
    Int3,
    Int4,

    /// Single boolean; BufferElement::ShaderDataTypeSize reports 1 byte for it.
    Bool,
};

/// Describes a single attribute of a vertex: its name, data type, size in
/// bytes and byte offset inside one vertex.
///
/// `size` is derived from `type` at construction. `offset` starts at 0 and is
/// filled in by BufferLayout when the element becomes part of a layout.
struct BufferElement {
    std::string name;       ///< Attribute name as given by the caller.
    ShaderDataType type;    ///< Data type of the attribute.
    uint32_t offset;        ///< Byte offset within a vertex; 0 until a BufferLayout assigns it.
    uint32_t size;          ///< Size in bytes, i.e. ShaderDataTypeSize(type).
    bool normalized;        ///< Stored as given; presumably consumed by the graphics backend -- confirm there.

    /// Returns the size in bytes of one value of `type`.
    ///
    /// Every component is counted as 4 bytes, except Bool which is 1 byte;
    /// None is 0 bytes. A value outside the enumeration triggers
    /// XEN_CORE_ASSERT and yields 0.
    static uint32_t ShaderDataTypeSize(ShaderDataType type) {
        switch (type) {
            case ShaderDataType::None:      return 0;
            case ShaderDataType::Float:     return 4;
            case ShaderDataType::Float2:    return 4 * 2;
            case ShaderDataType::Float3:    return 4 * 3;
            case ShaderDataType::Float4:    return 4 * 4;
            case ShaderDataType::Mat3:      return 4 * 3 * 3;
            case ShaderDataType::Mat4:      return 4 * 4 * 4;
            case ShaderDataType::Int:       return 4;
            case ShaderDataType::Int2:      return 4 * 2;
            case ShaderDataType::Int3:      return 4 * 3;
            case ShaderDataType::Int4:      return 4 * 4;
            case ShaderDataType::Bool:      return 1;
        }
        XEN_CORE_ASSERT(false, "Unknown shader data type");
        return 0;
    }

    /// Builds an element named `name` of type `type`.
    ///
    /// @param name        Attribute name (copied).
    /// @param type        Data type; determines `size`.
    /// @param normalized  Stored verbatim in `normalized`; defaults to false.
    ///
    /// The initializers are listed in member declaration order, which is the
    /// order in which C++ actually runs them (avoids -Wreorder).
    BufferElement(const std::string& name, ShaderDataType type, bool normalized = false)
        : name(name),
          type(type),
          offset(0),
          size(ShaderDataTypeSize(type)),
          normalized(normalized) {
    }

    /// Returns how many scalar components `type` consists of
    /// (e.g. 3 for Float3, 16 for Mat4).
    ///
    /// ShaderDataType::None has no component count: like any unknown value it
    /// triggers XEN_CORE_ASSERT and yields 0.
    uint32_t GetComponentCount() const {
        switch (type) {
            // Listed explicitly so the switch covers every enumerator; it
            // deliberately falls out to the assert below.
            case ShaderDataType::None:      break;
            case ShaderDataType::Float:     return 1;
            case ShaderDataType::Float2:    return 2;
            case ShaderDataType::Float3:    return 3;
            case ShaderDataType::Float4:    return 4;
            case ShaderDataType::Mat3:      return 3 * 3;
            case ShaderDataType::Mat4:      return 4 * 4;
            case ShaderDataType::Int:       return 1;
            case ShaderDataType::Int2:      return 2;
            case ShaderDataType::Int3:      return 3;
            case ShaderDataType::Int4:      return 4;
            case ShaderDataType::Bool:      return 1;
        }
        XEN_CORE_ASSERT(false, "Unknown element count");
        return 0;
    }

};

/// Ordered list of BufferElements together with the total size in bytes of
/// one vertex (the stride).
///
/// Constructing a layout from a list of elements assigns each element's
/// `offset` so that elements are packed back to back in list order.
class BufferLayout {
public:
    /// Creates an empty layout with a stride of 0.
    BufferLayout() = default;

    /// Copies `elements` and computes their offsets and the stride.
    ///
    /// Intentionally not `explicit`, so a braced list of elements converts
    /// implicitly to a BufferLayout.
    BufferLayout(const std::initializer_list<BufferElement>& elements) : elements_(elements) {
        calculateOffsetStride();
    }

    /// Read-only access to the elements, in layout order.
    inline const std::vector<BufferElement>& GetElements() const { return elements_; }

    /// Sum of the sizes of all elements, in bytes.
    inline uint32_t GetStride() const { return stride_; }

    // Mutable iteration. Note that changing an element through these
    // iterators does not recompute offsets or the stride.
    std::vector<BufferElement>::iterator begin() { return elements_.begin(); }
    std::vector<BufferElement>::iterator end() { return elements_.end(); }

    // Const iteration, so that a range-based for loop works on a
    // `const BufferLayout&` (e.g. the result of VertexBuffer::GetLayout()).
    std::vector<BufferElement>::const_iterator begin() const { return elements_.begin(); }
    std::vector<BufferElement>::const_iterator end() const { return elements_.end(); }

    std::vector<BufferElement>::const_iterator cbegin() const { return elements_.cbegin(); }
    std::vector<BufferElement>::const_iterator cend() const { return elements_.cend(); }

private:
    std::vector<BufferElement> elements_;
    uint32_t stride_ = 0;

    /// Assigns each element its byte offset and accumulates the stride.
    void calculateOffsetStride() {
        stride_ = 0;
        for (auto& element : elements_) {
            // The running total of the preceding sizes is this element's offset.
            element.offset = stride_;
            stride_ += element.size;
        }
    }
};

/// Abstract interface for a vertex buffer. Concrete implementations are
/// obtained through the static Create() factories, which are defined
/// outside this header.
class VertexBuffer {
public:
    virtual ~VertexBuffer() = default;

    /// Associates `layout` with this buffer.
    virtual void SetLayout(const BufferLayout& layout) = 0;

    /// Returns the layout associated with this buffer.
    virtual const BufferLayout& GetLayout() const = 0;

    /// Writes `size` units of `data` starting at `offset`.
    /// NOTE(review): offset and size are presumably in bytes -- confirm
    /// against the implementations.
    virtual void SetData(uint32_t offset, uint32_t size, const void* data) = 0;

    /// Makes this buffer the active one / deactivates it; the exact meaning
    /// is defined by the implementation.
    virtual void Bind() const = 0;
    virtual void Unbind() const = 0;

    /// Creates a buffer from `vertex_data`.
    /// NOTE(review): `size` is presumably the byte size of `vertex_data` --
    /// confirm against the implementation.
    static Ref<VertexBuffer> Create(const void* vertex_data, uint32_t size);

    /// Creates a buffer of the given `size` without initial data.
    static Ref<VertexBuffer> Create(uint32_t size);
};

/// Abstract interface for an index buffer holding 32-bit unsigned indices.
/// Concrete implementations are obtained through the static Create()
/// factory, which is defined outside this header.
class IndexBuffer {
public:
    virtual ~IndexBuffer() = default;

    /// Makes this buffer the active one / deactivates it; the exact meaning
    /// is defined by the implementation.
    virtual void Bind() const = 0;
    virtual void Unbind() const = 0;

    /// Returns the buffer's count.
    /// NOTE(review): presumably the number of indices passed to Create() --
    /// confirm against the implementations.
    virtual uint32_t GetCount() const = 0;

    /// Creates an index buffer from `indices`.
    /// NOTE(review): `count` is presumably a number of indices, not a byte
    /// size -- confirm against the implementation.
    static Ref<IndexBuffer> Create(const uint32_t* indices, uint32_t count);
};

}

#endif
